# -*- coding: UTF-8 -*-

import math
from typing import Dict, Any

from pytorch_lightning.callbacks import Callback
import pytorch_lightning as pl

import nni

from operator import itemgetter

# Names exported by ``from <module> import *``: the two callbacks defined below.
__all__ = [
    'BestAccuracy',
    'NNIreport',
]


class BestAccuracy(Callback):
    """Track the best top-k validation accuracy observed while fitting.

    After every validation epoch that is not part of the sanity check, the
    callback reads ``val/acc<top_k>`` from ``trainer.callback_metrics``,
    remembers the highest value together with the epoch it was seen on, and
    logs the running best as ``val_best_top<top_k>_acc``. The best value and
    epoch are returned from ``on_save_checkpoint`` and restored in
    ``on_load_checkpoint``; a summary line is printed when fitting ends.

    Args:
        top_k: The ``k`` of the monitored accuracy; selects the metric key
            ``val/acc<k>`` and the name of the logged "best" metric.
    """

    def __init__(self, top_k: int) -> None:
        super().__init__()
        # -inf means "nothing recorded yet": every real accuracy compares
        # greater, so the first validation epoch always becomes the best.
        self.best_accuracy = -math.inf
        self.best_epoch = -math.inf
        self.top_k = top_k

    def __str__(self):
        """Return a one-line summary of the best accuracy and its epoch."""
        # '%d' cannot format an infinite float (OverflowError), and that is
        # exactly the state when no validation epoch has been recorded.
        if math.isinf(self.best_epoch):
            return "Best Top-%d is not available (no validation epoch recorded)" % self.top_k
        return "Best Top-%d is %.2f on epoch %d" % (self.top_k, self.best_accuracy, self.best_epoch)

    def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Update the running best from this epoch's metric and log it.

        Raises:
            RuntimeError: If ``val/acc<top_k>`` is absent from
                ``trainer.callback_metrics``.
        """
        if trainer.sanity_checking:
            return

        metric_name = 'val/acc%d' % self.top_k
        metric = trainer.callback_metrics.get(metric_name)
        if metric is None:
            # Fail with an actionable message instead of the opaque
            # "'NoneType' object has no attribute 'item'".
            raise RuntimeError(
                "BestAccuracy could not find metric %r in trainer.callback_metrics; "
                "available metrics: %s" % (metric_name, sorted(trainer.callback_metrics))
            )

        accuracy = metric.item()
        if accuracy > self.best_accuracy:
            self.best_accuracy = accuracy
            self.best_epoch = trainer.current_epoch
        pl_module.log('val_best_top%d_acc' % self.top_k, self.best_accuracy, on_epoch=True)

    # NOTE(review): returning state from on_save_checkpoint / receiving
    # ``callback_state`` in on_load_checkpoint is the older Lightning callback
    # protocol; newer releases appear to use state_dict / load_state_dict
    # instead - confirm against the installed pytorch_lightning version.
    def on_save_checkpoint(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any]) -> dict:
        """Return the best accuracy and epoch, keyed by ``top_k``."""
        return {'best_accuracy@%d' % self.top_k: self.best_accuracy, 'best_epoch@%d' % self.top_k: self.best_epoch}

    def on_load_checkpoint(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", callback_state: Dict[str, Any]) -> None:
        """Restore the best accuracy and epoch from ``callback_state``."""
        self.best_accuracy = callback_state['best_accuracy@%d' % self.top_k]
        self.best_epoch = callback_state['best_epoch@%d' % self.top_k]

    def on_fit_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Print the summary line once fitting has finished."""
        print(self)


class NNIreport(Callback):
    """Report top-1 validation accuracy to NNI.

    After every validation epoch that is not part of the sanity check, the
    ``val/acc1`` entry of ``trainer.callback_metrics`` is sent to NNI as an
    intermediate result and folded into a running maximum. That maximum is
    sent as the final result when fitting ends, and is returned from
    ``on_save_checkpoint`` / restored in ``on_load_checkpoint``.
    """

    def __init__(self) -> None:
        super().__init__()
        # -inf means "nothing recorded yet"; any real accuracy replaces it.
        self.best_accuracy = -math.inf

    def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Send this epoch's accuracy to NNI and update the running best.

        Raises:
            RuntimeError: If ``val/acc1`` is absent from
                ``trainer.callback_metrics``.
        """
        if trainer.sanity_checking:
            return

        metric = trainer.callback_metrics.get('val/acc1')
        if metric is None:
            # Fail with an actionable message instead of the opaque
            # "'NoneType' object has no attribute 'item'".
            raise RuntimeError(
                "NNIreport could not find metric 'val/acc1' in trainer.callback_metrics; "
                "available metrics: %s" % sorted(trainer.callback_metrics)
            )

        # Read the value once and reuse it for both the report and the maximum.
        accuracy = metric.item()
        nni.report_intermediate_result(accuracy)
        self.best_accuracy = max(self.best_accuracy, accuracy)

    # NOTE(review): returning state from on_save_checkpoint / receiving
    # ``callback_state`` in on_load_checkpoint is the older Lightning callback
    # protocol; newer releases appear to use state_dict / load_state_dict
    # instead - confirm against the installed pytorch_lightning version.
    def on_save_checkpoint(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any]) -> dict:
        """Return the best accuracy seen so far."""
        return {'best_accuracy@1': self.best_accuracy}

    def on_load_checkpoint(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", callback_state: Dict[str, Any]) -> None:
        """Restore the best accuracy from ``callback_state``."""
        self.best_accuracy = callback_state['best_accuracy@1']

    def on_fit_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Send the best accuracy to NNI as the trial's final result."""
        nni.report_final_result(self.best_accuracy)
